"""通过shp面文件裁剪tif"""
from osgeo import gdal, osr, ogr, gdal_array
import os
import shapefile
import time
from pyproj import Proj, transform
import numpy as np
from tqdm import tqdm

import pandas  as pd

def ClipMinimumRectangle(path, outputpath, ignoreValue):
    """Crop a 2-D raster to the bounding box of its non-background pixels.

    All values listed in ``ignoreValue`` are treated as background and are
    rewritten to ``ignoreValue[0]`` before the bounding box is computed.
    The cropped array is written to ``outputpath`` with a geotransform whose
    origin is moved to the top-left corner of the cropped window.

    Only 2-D (single-band) arrays are supported.

    Args:
        path: Path of the input GeoTIFF, handed to ``readGeoTIFF``.
        outputpath: Path of the cropped GeoTIFF, handed to ``CreateGeoTiff``.
        ignoreValue: Sequence of background values. If it is empty, the value
            of the top-left pixel is used as the background.

    Raises:
        ValueError: If the raster contains nothing but background, so there
            is no rectangle to crop to.
    """
    im_data, im_geotrans, im_proj = readGeoTIFF(path)

    if len(ignoreValue) > 0:
        background = ignoreValue[0]
        # Collapse every ignore value onto a single background value.
        for value in ignoreValue[1:]:
            im_data[im_data == value] = background
    else:
        background = im_data[0, 0]

    # Compare against the background value itself. Comparing against the
    # first row (as was done before) can never keep row 0 and only works when
    # that row happens to be uniformly background.
    foreground = im_data != background
    col_index = np.flatnonzero(foreground.any(axis=0))
    row_index = np.flatnonzero(foreground.any(axis=1))
    if col_index.size == 0 or row_index.size == 0:
        raise ValueError(
            "Raster contains only background values; nothing to crop: {}".format(path)
        )

    start_x, end_x = int(col_index[0]), int(col_index[-1]) + 1
    start_y, end_y = int(row_index[0]), int(row_index[-1]) + 1

    im_data = im_data[start_y:end_y, start_x:end_x]

    # Shift the origin to pixel (start_x, start_y). The rotation terms
    # (indices 2 and 4) are included so the result is also right for rotated
    # rasters; they are zero for north-up images.
    new_geotrans = list(im_geotrans)
    new_geotrans[0] = im_geotrans[0] + start_x * im_geotrans[1] + start_y * im_geotrans[2]
    new_geotrans[3] = im_geotrans[3] + start_x * im_geotrans[4] + start_y * im_geotrans[5]

    CreateGeoTiff(outputpath, im_data, tuple(new_geotrans), im_proj)


def get_tif_info(tif_path):
    """Open a GeoTIFF and return its first band plus its spatial metadata.

    Args:
        tif_path: Path of the raster; must end with ``.tif`` or ``.TIF``.

    Returns:
        A tuple ``(img, dataset, gcs, pcs, extend, shape)`` where

        * ``img`` is band 1 read as an array of shape (height, width),
        * ``dataset`` is the open GDAL dataset,
        * ``gcs`` is the geographic coordinate system cloned from ``pcs``,
        * ``pcs`` is the spatial reference built from the dataset's projection WKT,
        * ``extend`` is the dataset's geotransform,
        * ``shape`` is ``(RasterYSize, RasterXSize)``.

    Raises:
        ValueError: If ``tif_path`` does not have a ``.tif`` / ``.TIF`` extension.
        RuntimeError: If GDAL cannot open the file.
    """
    if not tif_path.endswith(('.tif', '.TIF')):
        # Raising a bare string is itself a TypeError in Python 3, so a real
        # exception class is required here.
        raise ValueError("Unsupported file format: {!r}".format(tif_path))

    dataset = gdal.Open(tif_path)
    if dataset is None:
        # gdal.Open returns None instead of raising when exceptions are off.
        raise RuntimeError("GDAL could not open {!r}".format(tif_path))

    pcs = osr.SpatialReference()
    pcs.ImportFromWkt(dataset.GetProjection())
    gcs = pcs.CloneGeogCS()
    extend = dataset.GetGeoTransform()
    shape = (dataset.RasterYSize, dataset.RasterXSize)
    img = dataset.GetRasterBand(1).ReadAsArray()  # (height, width)
    return img, dataset, gcs, pcs, extend, shape


def clip_data(path, shp_path, out_path):
    """Clip a raster to the bounding rectangle of a shapefile.

    The shapefile's bounding box is interpreted as EPSG:4326 longitude/latitude
    and converted to UTM zone 50 (WGS84 ellipsoid) before being mapped to
    pixel offsets, so the raster is expected to be in that UTM zone.
    The clipped window is written as a Float32 GeoTIFF with the source
    projection.

    Args:
        path: Path of the raster to clip.
        shp_path: Path of the shapefile whose bounding box defines the window.
        out_path: Path of the GeoTIFF to create.

    Raises:
        RuntimeError: If GDAL cannot open the source or create the output.
        ValueError: If the bounding box does not overlap the raster.
    """
    srcImage = gdal.Open(path)
    if srcImage is None:
        raise RuntimeError("GDAL could not open {!r}".format(path))
    im_width = srcImage.RasterXSize  # number of raster columns
    im_height = srcImage.RasterYSize  # number of raster rows
    im_data = srcImage.ReadAsArray(0, 0, im_width, im_height)
    if im_data.ndim == 2:
        # Single-band rasters come back as (rows, cols); add a band axis so
        # the slicing and writing below work for any band count.
        im_data = im_data[np.newaxis, :, :]
    # Geotransform layout: (top-left x, pixel width, rotation,
    #                       top-left y, rotation, pixel height)
    geoTrans = srcImage.GetGeoTransform()
    im_proj = srcImage.GetProjection()

    # Read the bounding rectangle and release the shapefile handles at once.
    with shapefile.Reader(shp_path) as reader:
        minX, minY, maxX, maxY = reader.bbox

    # NOTE(review): Proj(init=...) and transform() appear to be deprecated in
    # pyproj 2+ in favour of Transformer - confirm before upgrading pyproj.
    WGS84 = Proj(init='EPSG:4326')
    utm = Proj(proj='utm', zone=50, ellps='WGS84', preserve_units=False)
    minX, minY = transform(WGS84, utm, minX, minY)
    maxX, maxY = transform(WGS84, utm, maxX, maxY)

    img_X, img_Y, img_W, img_H = geoTrans[0], geoTrans[3], geoTrans[1], geoTrans[5]
    # Convert the projected bounding box to pixel offsets.
    minX_loc, maxX_loc = int((minX - img_X) / img_W), int((maxX - img_X) / img_W)
    minY_loc, maxY_loc = int((maxY - img_Y) / img_H), int((minY - img_Y) / img_H)

    # Clamp to the raster: a negative offset would otherwise be read by numpy
    # as "count from the end" and silently select the wrong window.
    minX_loc = min(max(minX_loc, 0), im_width)
    maxX_loc = min(max(maxX_loc, 0), im_width)
    minY_loc = min(max(minY_loc, 0), im_height)
    maxY_loc = min(max(maxY_loc, 0), im_height)

    clip = im_data[:, minY_loc:maxY_loc, minX_loc:maxX_loc]
    if clip.shape[1] == 0 or clip.shape[2] == 0:
        raise ValueError(
            "Bounding box of {!r} does not overlap raster {!r}".format(shp_path, path)
        )

    # Anchor the output at the edge of the first clipped pixel rather than at
    # the raw bounding-box corner, so the clip stays aligned with the source
    # pixel grid.
    geoTrans = list(geoTrans)
    geoTrans[0] = img_X + minX_loc * img_W
    geoTrans[3] = img_Y + minY_loc * img_H

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(out_path, clip.shape[2], clip.shape[1], clip.shape[0],
                            gdal.GDT_Float32)
    if dataset is None:
        raise RuntimeError("GDAL could not create {!r}".format(out_path))
    dataset.SetGeoTransform(geoTrans)  # six affine coefficients
    dataset.SetProjection(im_proj)
    # WriteArray takes one 2-D array per band.
    for band_number, band_data in enumerate(clip, start=1):
        dataset.GetRasterBand(band_number).WriteArray(band_data)
    # Dropping the references flushes and closes the datasets.
    del dataset
    srcImage = None


if __name__ == '__main__':
    # Script entry point: for every polygon of the shapefile whose first
    # attribute is a non-zero class id, cut that polygon out of the source
    # GeoTIFF and store it under output/<class name>/<n>.tif.
    root = r'G:\哨兵2号数据\2022.4.12数据集'

    ShapefilePath = os.path.join(root, '范围/202103/lx.shp')   # polygon shapefile
    TiffPath = os.path.join(root, 'suibian_tif/suibian_2021_03.tif')   # raster to clip
    out_dir = os.path.join(root, 'Temp/temp.tif')   # intermediate raster, can be deleted afterwards
    outshp = os.path.join(root, r'Temp\temp.shp')   # intermediate shapefile, can be deleted afterwards
    os.makedirs(os.path.dirname(out_dir), exist_ok=True)

    # Projection information of the raster.
    img, dataset, gcs, pcs, extend, shape = get_tif_info(TiffPath)
    # WKT of the raster's geographic CRS, written next to each temporary
    # shapefile. str(gcs) is WKT, not a PROJ.4 string, so it must not be fed
    # through ImportFromProj4.
    prj_wkt = gcs.ExportToWkt()

    # Class id (first attribute of each polygon) -> output sub-directory name.
    Dict = {1: '水', 2: '建筑', 3: '植被', 4: '裸地', 5: '稀疏植被', 6: '云'}

    sf = shapefile.Reader(ShapefilePath)

    # Handy metadata for debugging: sf.shapeType, sf.encoding, sf.bbox,
    # sf.numRecords, sf.records(), sf.shapes().

    for index in tqdm(range(sf.numRecords)):
        index_record = sf.record(index)[0]  # class id of this polygon
        if index_record == 0:
            continue

        class_dir = os.path.join(root, 'output', Dict[index_record])
        os.makedirs(class_dir, exist_ok=True)
        length = len(os.listdir(class_dir))  # number of entries already stored for this class
        outputpath = os.path.join(class_dir, "{}.tif".format(length + 1))

        index_shape = sf.shape(index).points  # polygon boundary
        # NOTE(review): all points are written as one ring, so multi-part
        # polygons lose their part boundaries - confirm inputs are single-part.
        _shape = [[[x[0], x[1]] for x in index_shape]]

        # Write the polygon to a temporary single-feature shapefile
        # (shapeType 5 = polygon) together with its .prj file.
        w = shapefile.Writer(outshp, shapeType=5)
        w.field('FIRST_FLD')
        w.poly(_shape)
        w.record('First')  # one value for the single field defined above
        w.close()
        with open(outshp.replace(".shp", ".prj"), 'w') as f:
            f.write(prj_wkt)

        # First cut the raster to the polygon's bounding rectangle (temp.tif);
        # the warp below then only has to mask inside that rectangle.
        clip_data(TiffPath, outshp, out_dir)
        time.sleep(1)

        # The temporary shapefile holds exactly one feature, so no attribute
        # filter is needed; the previous cutlineWhere named a field that does
        # not exist in that file.
        ds = gdal.Warp(outputpath,
                       out_dir,
                       format='GTiff',
                       cutlineDSName=outshp,
                       dstNodata=0)  # pixels outside the polygon become nodata 0
        ds = None  # close dataset
